{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "regression.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/wx-chevalier/AIDL-Workbench/blob/master/pytorch/regression/regression.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jfBW5ECVh6qv",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        },
        "outputId": "b3a8082e-0cc5-4a74-b5cd-0fad49401d2f"
      },
      "source": [
        "#!/usr/bin/env python\n",
        "from __future__ import print_function\n",
        "from itertools import count\n",
        "\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "POLY_DEGREE = 4\n",
        "W_target = torch.randn(POLY_DEGREE, 1) * 5\n",
        "b_target = torch.randn(1) * 5\n",
        "\n",
        "\n",
        "def make_features(x):\n",
        "    \"\"\"Builds features i.e. a matrix with columns [x, x^2, x^3, x^4].\"\"\"\n",
        "    x = x.unsqueeze(1)\n",
        "    return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)\n",
        "\n",
        "\n",
        "def f(x):\n",
        "    \"\"\"Approximated function.\"\"\"\n",
        "    return x.mm(W_target) + b_target.item()\n",
        "\n",
        "\n",
        "def poly_desc(W, b):\n",
        "    \"\"\"Creates a string description of a polynomial.\"\"\"\n",
        "    result = 'y = '\n",
        "    for i, w in enumerate(W):\n",
        "        result += '{:+.2f} x^{} '.format(w, len(W) - i)\n",
        "    result += '{:+.2f}'.format(b[0])\n",
        "    return result\n",
        "\n",
        "\n",
        "def get_batch(batch_size=32):\n",
        "    \"\"\"Builds a batch i.e. (x, f(x)) pair.\"\"\"\n",
        "    random = torch.randn(batch_size)\n",
        "    x = make_features(random)\n",
        "    y = f(x)\n",
        "    return x, y\n",
        "\n",
        "\n",
        "# Define model\n",
        "fc = torch.nn.Linear(W_target.size(0), 1)\n",
        "\n",
        "for batch_idx in count(1):\n",
        "    # Get data\n",
        "    batch_x, batch_y = get_batch()\n",
        "\n",
        "    # Reset gradients\n",
        "    fc.zero_grad()\n",
        "\n",
        "    # Forward pass\n",
        "    output = F.smooth_l1_loss(fc(batch_x), batch_y)\n",
        "    loss = output.item()\n",
        "\n",
        "    # Backward pass\n",
        "    output.backward()\n",
        "\n",
        "    # Apply gradients\n",
        "    for param in fc.parameters():\n",
        "        param.data.add_(-0.1 * param.grad.data)\n",
        "\n",
        "    # Stop criterion\n",
        "    if loss < 1e-3:\n",
        "        break\n",
        "\n",
        "print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))\n",
        "print('==> Learned function:\\t' + poly_desc(fc.weight.view(-1), fc.bias))\n",
        "print('==> Actual function:\\t' + poly_desc(W_target.view(-1), b_target))"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Loss: 0.000876 after 727 batches\n",
            "==> Learned function:\ty = +5.23 x^4 +4.25 x^3 +3.78 x^2 +7.39 x^1 -3.62\n",
            "==> Actual function:\ty = +5.28 x^4 +4.25 x^3 +3.77 x^2 +7.39 x^1 -3.60\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}